#include "opencl_source_map.hpp" 
namespace MNN { 
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source text for the `layernorm_buf` kernel, kept as a string so it
// can be compiled at runtime. The concatenated literal IS the kernel source:
// any change inside the quotes changes the compiled kernel. The `//` lines in
// between are C++ comments and are not part of the string.
// NOTE(review): files of this shape look auto-generated from a .cl file; if so,
// edit that source instead of this string - confirm with the build scripts.
//
// Per row of `inside` contiguous elements starting at input + pos.y*inside:
//   mean = sum(x) / inside              (mean = 0 when RMSNORM is defined)
//   var  = sum((x - mean)^2) / inside
//   out  = (x - mean) * 1/sqrt(var + epsilon)
//          then * gamma[i] + beta[i] when GAMMA_BETA is defined
// Accumulation is done in float regardless of the FLOAT storage type.
//
// Compile-time macros the source reacts to:
//   MNN_SUPPORT_FP16 - enables the cl_khr_fp16 extension.
//   GAMMA_BETA       - adds the `gamma`/`beta` buffer parameters (indexed by
//                      element position within the row) and applies them.
//   RMSNORM          - skips the mean pass and uses mean = 0, so `var`
//                      becomes the mean of squares.
//   LOCAL_SIZE       - when > 1, the work-items of a work-group cooperate on
//                      one row (strided float4 loads + local-memory tree
//                      reduction); otherwise (<= 1 or undefined) each
//                      work-item processes its whole row with scalar loops.
//   PACK_LEAVE       - (LOCAL_SIZE > 1 path only) the float4 loops stop one
//                      vector early and the final group of up to 4 elements
//                      is handled with scalar loops by a single work-item.
// FLOAT and CONVERT_FLOAT4 are not defined here; presumably the host passes
// them as build options (e.g. float or half) - TODO confirm.
const char* layernorm_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void layernorm_buf(__private int global_dim0,__private int global_dim1,\n"
" __global const FLOAT*input,\n"
" __global FLOAT*output,\n"
" __private const int inside,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
// ---- Work-group path: LOCAL_SIZE work-items share one row. ----
// The tree reductions below halve the active range each step, so LOCAL_SIZE
// must be a power of two; otherwise some partial sums are never added in.
"#if LOCAL_SIZE>1\n"
" float local sum_mnn[LOCAL_SIZE];\n"
" #ifndef RMSNORM\n"
" float local sum_mean_mnn[LOCAL_SIZE];\n"
" #endif\n"
// NOTE(review): every barrier() below sits inside this bounds check. OpenCL
// requires all work-items of a work-group to reach a barrier, so this is only
// valid if the host launch makes the condition uniform per work-group - confirm.
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int lid=get_local_id(0);\n"
" const int offset=pos.y*inside;\n"
" const int inside_v4=(inside+3) >> 2;\n"
// Without PACK_LEAVE, every row is read as ceil(inside/4) float4 vectors; if
// inside is not a multiple of 4 that reads past the row. Presumably the host
// defines PACK_LEAVE in that case - TODO confirm.
" #ifdef PACK_LEAVE\n"
" const int loop=inside_v4-1;\n"
" const int inside_remain=inside-((inside_v4-1) << 2);\n"
" #else\n"
" const int loop=inside_v4;\n"
" #endif\n"
" \n"
" float4 in_sum=0;\n"
" int index=lid;\n"
" #ifdef RMSNORM\n"
" float4 mean=(float4)0;\n"
" #else\n"
// Pass 1 (mean): each work-item sums float4s lid, lid+LOCAL_SIZE, ... and
// stores its horizontal sum in sum_mean_mnn[lid].
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += in;\n"
" }\n"
" sum_mean_mnn[lid]=in_sum.x+in_sum.y+in_sum.z+ in_sum.w;\n"
" \n"
// After the strided loop the work-items' final `index` values are distinct and
// cover loop..loop+LOCAL_SIZE-1, so exactly one of them equals inside_v4-1 and
// adds the scalar tail.
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" sum_mean_mnn[lid]=sum_mean_mnn[lid]+in;\n"
" }\n"
" }\n"
" #endif\n"
" \n"
// Local-memory tree reduction: the row total ends up in sum_mean_mnn[0].
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum_mean_mnn[lid]=sum_mean_mnn[lid]+sum_mean_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum_mean_mnn[0]/(float4)inside;\n"
" #endif\n"
// Pass 2 (variance): same strided split over squared deviations. Using a
// separate local array (sum_mnn) means these writes cannot overwrite
// sum_mean_mnn[0] while other work-items may still be reading the mean.
" in_sum=0;\n"
" index=lid;\n"
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" sum_mnn[lid]=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i) {\n"
" float in=input[offset+index*4+i];\n"
" in=(in-mean.x)*(in-mean.x);\n"
" sum_mnn[lid]=sum_mnn[lid]+in;\n"
" }\n"
" }\n"
" #endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum_mnn[lid]=sum_mnn[lid]+sum_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
// Variance and reciprocal standard deviation (all four lanes hold the same value).
" float4 square_sum=sum_mnn[0]/(float4)inside;\n"
" float4 value=(float4)1.0f/(float4)sqrt(square_sum+(float4)epsilon);\n"
// Pass 3 (output): normalize, optionally scale/shift, and store, with the same
// strided float4 split plus scalar tail as the reduction passes.
" index=lid;\n"
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" #ifdef GAMMA_BETA\n"
" float4 out=(in-mean)*value*convert_float4(vload4(index,gamma))+convert_float4(vload4(index,beta));\n"
" #else\n"
" float4 out=(in-mean)*value;\n"
" #endif\n"
" vstore4(CONVERT_FLOAT4(out),index,output+offset);\n"
" }\n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" #ifdef GAMMA_BETA\n"
" float out=(in-mean.x)*value.x*(float)gamma[index*4+i]+(float)beta[index*4+i];\n"
" #else\n"
" float out=(in-mean.x)*value.x;\n"
" #endif\n"
" output[offset+index*4+i]=out;\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
// ---- Per-work-item path (LOCAL_SIZE <= 1 or undefined): one work-item walks
// its whole row with scalar loops (mean, variance, output). pos.x is only used
// in the bounds check, so work-items sharing pos.y write the same row.
"#else\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int offset=pos.y*inside;\n"
" float in_sum=0;\n"
" #ifdef RMSNORM\n"
" float mean=0;\n"
" #else\n"
" for(int index=0; index<inside; index++){\n"
" in_sum += (float)input[offset+index];\n"
" }\n"
" float mean=in_sum/inside;\n"
" #endif\n"
" in_sum=0;\n"
" for(int index=0; index<inside; index++){\n"
" float in=(float)input[offset+index];\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" float square_sum=in_sum/inside;\n"
" float value=1.0f/sqrt(square_sum+epsilon);\n"
" for(int i=0; i<inside; ++i){\n"
" float in=input[offset+i];\n"
" #ifdef GAMMA_BETA\n"
" float out=(in-mean)*value*(float)gamma[i]+(float)beta[i];\n"
" #else\n"
" float out=(in-mean)*value;\n"
" #endif\n"
" output[offset+i]=out;\n"
" }\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
}
